{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 5-8 回调函数callbacks\n",
    "\n",
    "tf.keras的回调函数实际上是一个类，一般是在model.fit时作为参数指定，用于控制在训练过程开始或者在训练过程结束，在每个epoch训练开始或者训练结束，在每个batch训练开始或者训练结束时执行一些操作，例如收集一些日志信息，改变学习率等超参数，提前终止训练过程等等。\n",
    "\n",
    "同样地，针对model.evaluate或者model.predict也可以指定callbacks参数，用于控制在评估或预测开始或者结束时，在每个batch开始或者结束时执行一些操作，但这种用法相对少见。\n",
    "\n",
    "大部分时候，keras.callbacks子模块中定义的回调函数类已经足够使用了，如果有特定的需要，我们也可以通过对keras.callbacks.Callbacks实施子类化构造自定义的回调函数。\n",
    "\n",
    "所有回调函数都继承至 keras.callbacks.Callbacks基类，拥有params和model这两个属性。\n",
    "\n",
    "其中params 是一个dict，记录了 training parameters (eg. verbosity, batch size, number of epochs...).\n",
    "\n",
    "model即当前关联的模型的引用。\n",
    "\n",
    "此外，对于回调类中的一些方法如on_epoch_begin,on_batch_end，还会有一个输入参数logs, 提供有关当前epoch或者batch的一些信息，并能够记录计算结果，如果model.fit指定了多个回调函数类，这些logs变量将在这些回调函数类的同名函数间依顺序传递。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 一、内置回调函数"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "* BaseLogger： 收集每个epoch上metrics在各个batch上的平均值，对stateful_metrics参数中的带中间状态的指标直接拿最终值无需对各个batch平均，指标均值结果将添加到logs变量中。该回调函数被所有模型默认添加，且是第一个被添加的。\n",
    "\n",
    "* History： 将BaseLogger计算的各个epoch的metrics结果记录到history这个dict变量中，并作为model.fit的返回值。该回调函数被所有模型默认添加，在BaseLogger之后被添加。\n",
    "\n",
    "* EarlyStopping： 当被监控指标在设定的若干个epoch后没有提升，则提前终止训练。\n",
    "\n",
    "* TensorBoard： 为Tensorboard可视化保存日志信息。支持评估指标，计算图，模型参数等的可视化。\n",
    "\n",
    "* ModelCheckpoint： 在每个epoch后保存模型。\n",
    "\n",
    "* ReduceLROnPlateau：如果监控指标在设定的若干个epoch后没有提升，则以一定的因子减少学习率。\n",
    "\n",
    "* TerminateOnNaN：如果遇到loss为NaN，提前终止训练。\n",
    "\n",
    "* LearningRateScheduler：学习率控制器。给定学习率lr和epoch的函数关系，根据该函数关系在每个epoch前调整学习率。\n",
    "\n",
    "* CSVLogger：将每个epoch后的logs结果记录到CSV文件中。\n",
    "\n",
    "* ProgbarLogger：将每个epoch后的logs结果打印到标准输出流中。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 二、自定义回调函数\n",
    "\n",
    "可以使用callbacks.LambdaCallback编写较为简单的回调函数，也可以通过对callbacks.Callback子类化编写更加复杂的回调函数逻辑。\n",
    "\n",
    "如果需要深入学习tf.Keras中的回调函数，不要犹豫阅读内置回调函数的源代码。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "import tensorflow as tf\n",
    "from tensorflow.keras import layers,models,losses,metrics,callbacks\n",
    "import tensorflow.keras.backend as K "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 示范使用LambdaCallback编写较为简单的回调函数\n",
    "\n",
    "import json\n",
    "json_log = open('./data/keras_log.json', mode='wt', buffering=1)\n",
    "json_logging_callback = callbacks.LambdaCallback(\n",
    "    on_epoch_end=lambda epoch, logs: json_log.write(\n",
    "        json.dumps(dict(epoch = epoch,**logs)) + '\\n'),\n",
    "    on_train_end=lambda logs: json_log.close()\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 示范通过Callback子类化编写回调函数（LearningRateScheduler的源代码）\n",
    "\n",
    "class LearningRateScheduler(callbacks.Callback):\n",
    "    \n",
    "    def __init__(self, schedule, verbose=0):\n",
    "        super(LearningRateScheduler, self).__init__()\n",
    "        self.schedule = schedule\n",
    "        self.verbose = verbose\n",
    "\n",
    "    def on_epoch_begin(self, epoch, logs=None):\n",
    "        if not hasattr(self.model.optimizer, 'lr'):\n",
    "            raise ValueError('Optimizer must have a \"lr\" attribute.')\n",
    "        try:  \n",
    "            lr = float(K.get_value(self.model.optimizer.lr))\n",
    "            lr = self.schedule(epoch, lr)\n",
    "        except TypeError:  # Support for old API for backward compatibility\n",
    "            lr = self.schedule(epoch)\n",
    "        if not isinstance(lr, (tf.Tensor, float, np.float32, np.float64)):\n",
    "            raise ValueError('The output of the \"schedule\" function '\n",
    "                             'should be float.')\n",
    "        if isinstance(lr, ops.Tensor) and not lr.dtype.is_floating:\n",
    "            raise ValueError('The dtype of Tensor should be float')\n",
    "        K.set_value(self.model.optimizer.lr, K.get_value(lr))\n",
    "        if self.verbose > 0:\n",
    "            print('\\nEpoch %05d: LearningRateScheduler reducing learning '\n",
    "                 'rate to %s.' % (epoch + 1, lr))\n",
    "\n",
    "    def on_epoch_end(self, epoch, logs=None):\n",
    "        logs = logs or {}\n",
    "        logs['lr'] = K.get_value(self.model.optimizer.lr)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
